What is Triton
Triton is a programming language and compiler infrastructure built around an SPMD (Single Program, Multiple Data) execution model, currently focused on GPU kernel development.
Conceptually, Triton consists of three layers:
① Triton Language (Python DSL)
The Triton Language resides under python/triton/language and can be divided into two main parts:
- Triton Language Operations: these can be invoked as functions inside a kernel and are processed through visit_Call during Python AST traversal.
- Operators inside a kernel (e.g., +, -, *, /): these are typically handled through visit_BinOp during Python AST traversal.
Here is a simple vector-add kernel:
import torch
import triton
import triton.language as tl

DEVICE = triton.runtime.driver.active.get_active_torch_device()

@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    tl.store(output_ptr + offsets, output, mask=mask)
Explanation
- x_ptr, y_ptr, and output_ptr point to the starting addresses of the input and output GPU arrays. When a Triton kernel is invoked, any PyTorch (or NumPy/CuPy) tensor passed as an argument is automatically converted into a pointer referencing its underlying device memory.
- n_elements specifies the total number of elements in the vector. This parameter allows the kernel to safely determine valid memory bounds during loads and stores.
- BLOCK_SIZE: tl.constexpr is a compile-time constant that defines how many elements each program instance (block) handles. For example, setting BLOCK_SIZE=1024 means a single program instance operates on 1024 elements at once using vectorized instructions.
- Inside the kernel, tl.program_id(axis=0) returns the unique index of the current program instance along grid dimension 0. For this vector-addition kernel, we launch a 1-dimensional program grid where each program instance handles one block of data.
- We compute the range of offsets from the block’s starting index up to block_start + BLOCK_SIZE - 1. tl.arange(0, BLOCK_SIZE) creates a vector of block-local indices with the values 0, 1, ..., BLOCK_SIZE - 1. By adding block_start to this vector, we obtain the absolute indices in the global array that this program instance will process.
- We create a mask boolean vector that marks which offsets are within bounds, i.e., where offset < n_elements. Any index beyond the array length will have a mask value of false (for example, if n_elements is not a multiple of BLOCK_SIZE, the final block will contain some out-of-range offsets). Triton uses this mask to perform memory accesses safely without requiring explicit branching.
- tl.load reads data from memory at the given address (pointer). When we call tl.load(x_ptr + offsets, mask=mask), Triton issues vectorized load instructions for those positions. For any element where the mask is false, the load is effectively skipped (or replaced with a placeholder value to avoid illegal memory access). The same logic applies to loading y.
- We then perform the elementwise addition output = x + y. Because Triton executes vectorized operations, this addition is applied to the entire block of elements at once. Conceptually, it resembles performing NumPy-style array addition on a slice of data, but here it runs in parallel within a GPU block.
- Finally, tl.store(output_ptr + offsets, output, mask=mask) writes the results for all valid indices back to global memory.
- Since each program instance processes BLOCK_SIZE elements and all instances run in parallel, the entire vector is added in a single kernel launch (a host-side launch sketch follows this list).
**Kernel Decorator:** the @triton.jit decorator marks a Python function as a Triton kernel. It works by traversing the abstract syntax tree (AST) of the decorated function and dynamically generating Triton-IR using standard SSA construction algorithms.
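For completeness, here is a host-side launcher in the style of the official vector-add tutorial. It reuses the imports, DEVICE, and add_kernel defined above; the add wrapper name and the tensor size are illustrative choices, not part of the kernel itself.

def add(x: torch.Tensor, y: torch.Tensor):
    output = torch.empty_like(x)
    n_elements = output.numel()
    # One program instance per block of BLOCK_SIZE elements:
    # e.g. n_elements = 98432 and BLOCK_SIZE = 1024 give cdiv(98432, 1024) = 97 instances;
    # the last instance covers the tail and relies on the mask.
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    return output

x = torch.rand(98432, device=DEVICE)
y = torch.rand(98432, device=DEVICE)
print(add(x, y))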
② Triton Intermediate Representation (TTIR)
When a Triton kernel is compiled, the compiler first traverses the abstract syntax tree (AST) of the decorated Python function. From this AST, Triton generates an intermediate representation called Triton-IR (TTIR). Triton-IR is a machine-independent, unoptimized representation of the kernel.
Triton Language → TTIR Workflow
JIT Entry: Compiler Invocation
When a Triton kernel is invoked for the first time, the @triton.jit decorator triggers just-in-time (JIT) compilation:
kernel = self.compile(...)
Here, self.compile launches the entire compilation process. A key step in this process is:
return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns)
ast_to_ttir converts the Python AST into Triton’s intermediate representation (TTIR).
AST → TTIR: Generating TTIR from Python AST (CodeGen Phase)
Internally, ast_to_ttir traverses the Python AST using the visitor pattern:
ret = super().visit(node)
This line marks the critical entry point where the compiler visits each Python AST node and transforms it into a corresponding Triton IR node.
Handling Binary Operations: BinOp → Addition
Vector Addition Example
For the vector addition operation:
output = x + y
Python’s AST represents this expression as a BinOp node, so the visitor enters:
visit_BinOp(self, node)
For tensor operands, visit_BinOp dispatches the + operator to the __add__ wrapper defined in:
python/triton/language/core.py
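To make this concrete, the expression can be inspected with Python’s standard ast module. The small NodeVisitor below only illustrates the visit_<NodeType> dispatch pattern that Triton’s code generator follows; it is not Triton’s actual visitor:

import ast

tree = ast.parse("output = x + y")
print(ast.dump(tree.body[0].value))
# BinOp(left=Name(id='x', ctx=Load()), op=Add(), right=Name(id='y', ctx=Load()))

# Illustrative visitor: ast.NodeVisitor dispatches each node type to a matching
# visit_<NodeType> method, the same pattern used when Triton walks a kernel's AST.
class BinOpPrinter(ast.NodeVisitor):
    def visit_BinOp(self, node):
        print("BinOp with operator:", type(node.op).__name__)  # prints: Add
        self.generic_visit(node)

BinOpPrinter().visit(tree)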
Mapping to Triton Builtins (via Wrappers)
Through this wrapper, the + operator is mapped to Triton’s builtin addition operation (add). The wrapper calls the semantic layer to perform type-specific analysis; the core logic resides in:
python/triton/language/semantic.py
For example, when adding two floats:
return tl.tensor(builder.create_fadd(input.handle, other.handle), input.type)
This means:
- tl.tensor: the result is wrapped in a Triton tensor that keeps the input’s type
- fadd: represents floating-point addition
builder.create_fadd → C++ IR Construction
On the C++ side, create_fadd corresponds to:
return self.create<arith::AddFOp>(lhs, rhs);
In MLIR terms, this generates an arith.addf operation. At this stage, the Python expression x + y has been fully lowered to an arith.addf instruction in TTIR.
Example: Triton IR for a Vector Addition Kernel
For the vector-addition kernel above, the generated TTIR looks like this:
module {
  tt.func public @add_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("python/tutorials/01-vector-add.py":28:0), %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("python/tutorials/01-vector-add.py":28:0), %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("python/tutorials/01-vector-add.py":28:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("python/tutorials/01-vector-add.py":28:0)) attributes {noinline = false} {
    %0 = tt.get_program_id x : i32 loc(#loc1)
    %c1024_i32 = arith.constant 1024 : i32 loc(#loc2)
    %1 = arith.muli %0, %c1024_i32 : i32 loc(#loc2)
    %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> loc(#loc3)
    %3 = tt.splat %1 : i32 -> tensor<1024xi32> loc(#loc4)
    %4 = arith.addi %3, %2 : tensor<1024xi32> loc(#loc4)
    %5 = tt.splat %arg3 : i32 -> tensor<1024xi32> loc(#loc5)
    %6 = arith.cmpi slt, %4, %5 : tensor<1024xi32> loc(#loc5)
    %7 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc6)
    %8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc6)
    %9 = tt.load %8, %6 : tensor<1024x!tt.ptr<f32>> loc(#loc7)
    %10 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc8)
    %11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc8)
    %12 = tt.load %11, %6 : tensor<1024x!tt.ptr<f32>> loc(#loc9)
    %13 = arith.addf %9, %12 : tensor<1024xf32> loc(#loc10)
    %14 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc11)
    %15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc11)
    tt.store %15, %13, %6 : tensor<1024x!tt.ptr<f32>> loc(#loc12)
    tt.return loc(#loc13)
  } loc(#loc)
} loc(#loc)
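If you want to reproduce a dump like this yourself, one option is to read the .asm dictionary on the handle returned by a kernel launch. This is a minimal sketch assuming a recent Triton release; the attribute name and the available keys ('ttir', 'ttgir', 'llir', 'ptx', 'cubin') can vary between versions. The same dictionary also exposes the lower-level artifacts produced by the backend pipeline described next.

# Reuses DEVICE and add_kernel from the vector-add example above.
import torch
import triton

x = torch.rand(98432, device=DEVICE)
y = torch.rand(98432, device=DEVICE)
output = torch.empty_like(x)
n_elements = output.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)

compiled = add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
print(list(compiled.asm.keys()))  # typically ['ttir', 'ttgir', 'llir', 'ptx', 'cubin'] on NVIDIA
print(compiled.asm['ttir'])       # the Triton-IR module, similar to the dump above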
③ Backend (Example: NVIDIA GPU)
Triton uses an MLIR-based pass pipeline to automatically optimize kernels and lower the machine-independent TTIR to GPU-specific TTGIR. The compilation workflow then proceeds as follows:
- TTGIR → LLVM IR → PTX → CUBIN
- TTGIR is progressively lowered to LLVM IR, then translated to PTX.
- The PTX code is compiled by ptxas into a CUBIN (GPU binary).
- Finally, the CUBIN is executed via the CUDA runtime on the GPU.
- The GPU architecture (Turing, Ampere, Hopper, etc.) influences the selection of MLIR passes. Triton leverages this information to automatically generate highly optimized code tailored for the target NVIDIA GPU.
Spine-Triton: Why Does SpacemiT Build Triton Support?
Although Triton offers many advantages, there are several limitations when targeting CPUs:
- Most Triton optimizations and semantics target GPU architectures
- The existing x86 Triton-CPU project provides limited performance
- GPU-style tiling, memory hierarchy, and threading models do not map efficiently to CPUs
- There is a lack of practical experience with CPU-adapted, unified-memory Triton kernel scheduling.
As a result, almost all projects in the Triton ecosystem are ultimately limited to running on GPUs. SpacemiT aims to achieve something bigger: enable Triton-written operators to run efficiently on RISC-V AI CPUs.